import os
import numpy as np
import h5py
from sympy.utilities.lambdify import lambdify
from ar_sim.continuum_action import ContinuumAction

class PathIntegralSampler:
    """
    Metropolis–Hastings sampler for path-integral over field configurations.

    The action is obtained symbolically from ``ContinuumAction`` and compiled
    to a NumPy-backed callable with ``lambdify``. Each step proposes a
    Gaussian perturbation of the whole configuration ``q`` and accepts it
    with probability ``min(1, exp(-(S_new - S_current) / temperature))``.

    Parameters
    ----------
    n_vals:
        Forwarded to ``ContinuumAction``; its length fixes the number of
        components ``N`` of the configuration vector.
    pivot_params:
        Forwarded unchanged to ``ContinuumAction``.
    sigma:
        Forwarded unchanged to ``ContinuumAction``.
    q_init:
        Optional starting configuration (any array-like of ``N`` numbers).
        It is copied, so the caller's object is never modified. Defaults to
        the all-zero configuration.
    sweep:
        Number of proposal/accept steps performed by :meth:`run`.
    step_size:
        Standard deviation of the Gaussian proposal added to every component.
    temperature:
        Divisor applied to the action difference in the acceptance test;
        must be strictly positive when :meth:`run` is called.
    save_interval:
        A snapshot of ``q`` is stored every ``save_interval`` steps; must be
        a positive integer when :meth:`run` is called.
    output:
        Path of the HDF5 file written by :meth:`run`.

    Raises
    ------
    ValueError
        If ``q_init`` is given but is not a 1-D sequence of length
        ``len(n_vals)``.
    """
    def __init__(self,
                 n_vals: np.ndarray,
                 pivot_params: dict,
                 sigma: float = 1.0,
                 q_init: np.ndarray = None,
                 sweep: int = 100000,
                 step_size: float = 0.1,
                 temperature: float = 1.0,
                 save_interval: int = 1000,
                 output: str = "results/samples.h5"):
        # Setup continuum action and lambdified action function
        self.cont_action = ContinuumAction(n_vals, pivot_params, sigma)
        self.q_syms = self.cont_action.q
        self.action_expr = self.cont_action.discrete_action()
        self.action_fn = lambdify(self.q_syms, self.action_expr, "numpy")

        # Sampler parameters
        self.sweep = sweep
        self.step_size = step_size
        self.temperature = temperature
        self.save_interval = save_interval
        self.output = output

        # Initialize field configuration q
        self.N = len(n_vals)
        if q_init is None:
            self.q = np.zeros(self.N)
        else:
            # np.array copies by default, accepts lists/tuples as well as
            # ndarrays, and gives the chain a float dtype from the start.
            self.q = np.array(q_init, dtype=float)
            if self.q.shape != (self.N,):
                raise ValueError(
                    f"q_init must have shape ({self.N},), got {self.q.shape}"
                )

    def _validate_parameters(self):
        """Raise ``ValueError`` if the run parameters cannot drive a chain."""
        # Checked at run time (not only at construction) so that attributes
        # reassigned after __init__ are covered too.
        if self.save_interval <= 0:
            raise ValueError(
                f"save_interval must be a positive integer, got {self.save_interval!r}"
            )
        if self.sweep < 0:
            raise ValueError(
                f"sweep must be a non-negative integer, got {self.sweep!r}"
            )
        # `not x > 0` (rather than `x <= 0`) also rejects NaN.
        if not self.temperature > 0:
            raise ValueError(
                f"temperature must be strictly positive, got {self.temperature!r}"
            )

    def _sample(self):
        """
        Advance the Markov chain by ``self.sweep`` steps and collect snapshots.

        Returns an array of shape ``(sweep // save_interval + 1, N)`` whose
        first row is the configuration the chain started from and whose
        subsequent rows are the configurations after every
        ``save_interval``-th step. ``self.q`` is left at the final state.
        """
        # Pre-allocate sample storage (+1 for the starting configuration)
        num_samples = self.sweep // self.save_interval + 1
        samples = np.zeros((num_samples, self.N))

        # Initial action value
        S_current = float(self.action_fn(*self.q))
        samples[0] = self.q
        idx = 1

        # Sampling loop
        for step in range(1, self.sweep + 1):
            # Propose new configuration
            q_new = self.q + np.random.normal(scale=self.step_size, size=self.N)
            S_new = float(self.action_fn(*q_new))

            # Metropolis criterion. The uniform number is drawn on every step
            # so the random stream is consumed in a fixed order; a move that
            # does not raise the action is accepted without evaluating exp(),
            # which would overflow for a large decrease. A NaN difference
            # fails both comparisons and the move is rejected.
            u = np.random.rand()
            delta = S_new - S_current
            if delta <= 0 or u < np.exp(-delta / self.temperature):
                self.q = q_new
                S_current = S_new

            # Save sample at intervals
            if step % self.save_interval == 0:
                samples[idx] = self.q
                idx += 1

        return samples

    def run(self):
        """
        Run the sampler and write the collected snapshots to ``self.output``.

        The output file is (over)written as HDF5 with a single dataset named
        ``"samples"``; missing parent directories are created first. A
        one-line summary is printed on completion.

        Raises
        ------
        ValueError
            If ``save_interval`` is not positive, ``sweep`` is negative, or
            ``temperature`` is not strictly positive.
        """
        self._validate_parameters()

        # Prepare results directory; exist_ok avoids a check-then-create race
        out_dir = os.path.dirname(self.output)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        samples = self._sample()

        # Write samples to HDF5
        with h5py.File(self.output, "w") as f:
            f.create_dataset("samples", data=samples)

        print(f"Saved {len(samples)} samples to {self.output}")

if __name__ == "__main__":
    # Demo entry point: build a sampler from the values returned by
    # load_D_values() and run it with every optional setting at its default.
    from ar_sim.common.fractal_fits import load_D_values

    # load_D_values() returns three items; the third is not used here.
    n_vals, D_vals, _sigma_vals = load_D_values()
    PathIntegralSampler(
        n_vals,
        {"a": -1.35, "b": 3.70, "D_vals": D_vals},
    ).run()
